import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD
from torch.autograd import Variable
import torchvision

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.svm import SVC
from sklearn.metrics import classification_report, confusion_matrix,accuracy_score
import seaborn as sns
from scipy.ndimage import rotate
from scipy.ndimage import rotate

import matplotlib.pyplot as plt

from model import ArtNet

from utils import *

# Evaluate the trained ArtNet checkpoint on the held-out test set.
# `test_path` and `transformer` are expected to come from `utils` (star import above).

# Run on the GPU when available. The model and every batch must live on the same
# device; otherwise the forward pass fails with a device-mismatch RuntimeError.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): 11 is the number of output classes; presumably equals len(classes) — confirm.
model = ArtNet(11)
model_path = "/home/shivam-wiz/Downloads/MLPR___/Trial/best_checkpoint.model"
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()  # inference mode for layers such as dropout / batch-norm, if ArtNet has any

# Sample order does not affect the metrics; shuffle=False keeps runs reproducible.
test_loader = DataLoader(
    torchvision.datasets.ImageFolder(test_path, transform=transformer),
    batch_size=32,
    shuffle=False,
)

test_accuracy = 0.0
predictions = []
actual_labels = []
# Inference only: no_grad avoids building the autograd graph (less memory, faster).
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        # Index of the highest score along the class dimension = predicted class.
        prediction = outputs.argmax(dim=1)
        test_accuracy += int((prediction == labels).sum())
        predictions.extend(prediction.tolist())
        actual_labels.extend(labels.tolist())

# Divide by the number of samples actually evaluated, so the denominator always
# matches what ImageFolder loaded (an externally computed count may not).
test_accuracy = test_accuracy / len(actual_labels)

print(f"The accuracy of the model is {round(test_accuracy * 100)}%")
print()

# Per-class precision / recall / F1 on the test set.
# Passing `labels` explicitly pins the report to every class index 0..len(classes)-1.
# Without it, sklearn raises ValueError whenever some class never appears in
# either the true labels or the predictions, because the count of observed
# labels then differs from len(target_names).
# NOTE(review): assumes `classes` (from utils) is ordered to match the dataset's
# integer class indices — confirm.
report = classification_report(
    actual_labels,
    predictions,
    labels=list(range(len(classes))),
    target_names=classes,
)
print(report)

# Confusion matrix over all classes. Passing `labels` keeps the matrix
# len(classes) x len(classes) even if a class never occurs, so row/column i
# always corresponds to class index i.
conf_matrix = confusion_matrix(
    actual_labels, predictions, labels=list(range(len(classes)))
)
sns.set()
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt="d",
    cmap="Blues",
    cbar=True,
    # Show class names instead of bare indices on the axes.
    # NOTE(review): assumes `classes` order matches the label indices — confirm.
    xticklabels=classes,
    yticklabels=classes,
)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Confusion Matrix")
plt.tight_layout()  # keep long class-name tick labels from being clipped
plt.show()